/*
* (C) 2023 Jack Lloyd
*     2023 Fabian Albert, René Meusel, Amos Treiber - Rohde & Schwarz Cybersecurity
*
* Botan is released under the Simplified BSD License (see license.txt)
*/

#include "test_rng.h"
#include "tests.h"

#if defined(BOTAN_HAS_SPHINCS_PLUS_COMMON) && defined(BOTAN_HAS_AES)

   #include <botan/assert.h>
   #include <botan/hash.h>
   #include <botan/pk_algs.h>
   #include <botan/pubkey.h>
   #include <botan/secmem.h>
   #include <botan/sp_parameters.h>
   #include <botan/sphincsplus.h>
   #include <botan/internal/loadstor.h>
   #include <botan/internal/sp_hash.h>
   #include <algorithm>

   #include "test_pubkey.h"

namespace Botan_Tests {

/**
 * Known-answer tests for all implemented SPHINCS+ and SLH-DSA instances.
 *
 * Every vector provides a DRBG seed (from which all key generation and
 * signing randomness is derived), the expected public and private key bits,
 * a message and a hash of the expected randomized signature. Some vectors
 * additionally carry a hash of the expected deterministic signature
 * (HashSigDet). Subclasses only select which vector file is loaded.
 */
class SPHINCS_Plus_Test_Base : public Text_Based_Test {
   public:
      /// @param kat_path path of the KAT vector file (e.g. "pubkey/slh_dsa.vec")
      explicit SPHINCS_Plus_Test_Base(std::string_view kat_path) :
            Text_Based_Test(std::string(kat_path), "SphincsParameterSet,seed,pk,sk,msg,HashSigRand", "HashSigDet") {}

      /**
       * Skip vectors whose parameter set is not available in this build,
       * SHAKE-based vectors when SHA-3 (needed to hash their signatures) is
       * missing, and the slow "small" instances unless --run-long-tests is set.
       */
      bool skip_this_test(const std::string& /*header*/, const VarMap& vars) override {
         auto params = Botan::Sphincs_Parameters::create(vars.get_req_str("SphincsParameterSet"));

         if(!params.is_available()) {
            return true;
         }

   #if not defined(BOTAN_HAS_SHA3)
         // The signatures of SHAKE-based instances are hashed with SHA-3 in the
         // test file, so they cannot be checked without SHA-3 support.
         if(params.hash_type() == Botan::Sphincs_Hash_Type::Shake256) {
            return true;
         }
   #endif

         // Execute the small (slow) instances only with --run-long-tests
         switch(params.parameter_set()) {
            case Botan::Sphincs_Parameter_Set::Sphincs128Fast:
            case Botan::Sphincs_Parameter_Set::Sphincs192Fast:
            case Botan::Sphincs_Parameter_Set::Sphincs256Fast:
            case Botan::Sphincs_Parameter_Set::SLHDSA128Fast:
            case Botan::Sphincs_Parameter_Set::SLHDSA192Fast:
            case Botan::Sphincs_Parameter_Set::SLHDSA256Fast:
               return false;
            case Botan::Sphincs_Parameter_Set::Sphincs128Small:
            case Botan::Sphincs_Parameter_Set::Sphincs192Small:
            case Botan::Sphincs_Parameter_Set::Sphincs256Small:
            case Botan::Sphincs_Parameter_Set::SLHDSA128Small:
            case Botan::Sphincs_Parameter_Set::SLHDSA192Small:
            case Botan::Sphincs_Parameter_Set::SLHDSA256Small:
               return !Test::run_long_tests();
         }
         BOTAN_ASSERT_UNREACHABLE();
      }

      Test::Result run_one_test(const std::string& /*header*/, const VarMap& vars) final {
         const std::string param_set_name = vars.get_req_str("SphincsParameterSet");
         auto params = Botan::Sphincs_Parameters::create(param_set_name);
         Test::Result result(params.is_slh_dsa() ? "SLH-DSA" : "SPHINCS+");

         const std::vector<uint8_t> seed_ref = vars.get_req_bin("seed");
         const std::vector<uint8_t> msg_ref = vars.get_req_bin("msg");
         const std::vector<uint8_t> pk_ref = vars.get_req_bin("pk");
         const std::vector<uint8_t> sk_ref = vars.get_req_bin("sk");
         const std::vector<uint8_t> sig_rand_hash = vars.get_req_bin("HashSigRand");

         // The hash of the deterministic signature is optional: not every
         // vector provides it.
         const auto sig_det_hash = [&]() -> std::optional<std::vector<uint8_t>> {
            if(vars.has_key("HashSigDet")) {
               return vars.get_opt_bin("HashSigDet");
            }
            return std::nullopt;
         }();

         // Depending on the hash type of the instance, the resulting signatures
         // are hashed either with SHA-3 or SHA-256 to reduce the inner
         // dependencies on other hash function modules.
         const std::string hash_algo_spec =
            (params.hash_type() == Botan::Sphincs_Hash_Type::Shake256) ? "SHA-3(256)" : "SHA-256";
         auto hash = Botan::HashFunction::create_or_throw(hash_algo_spec);

         /*
          * To get sk_seed || sk_prf || pk_seed and opt_rand from the given seed
          * (from the KAT), we create a CTR_DRBG_AES256 rng and "simulate" the
          * invocation pattern of the reference implementation. The generated
          * randomness is fed to a fixed output rng, to allow for our (slightly
          * different) invocation pattern. Be careful: some KATs are generated
          * by 3 separate invocations of drbg.random_vals, which is a different
          * pattern.
          */
         auto kat_rng = CTR_DRBG_AES256(seed_ref);
         Fixed_Output_RNG fixed_rng;
         fixed_rng.add_entropy(kat_rng.random_vec<std::vector<uint8_t>>(3 * params.n()));

         // Depending on the parameter set, up to two signatures are created in
         // 'Randomized' mode below, so the signing entropy is pushed twice.
         const auto entropy_for_signing = kat_rng.random_vec<std::vector<uint8_t>>(1 * params.n());
         fixed_rng.add_entropy(entropy_for_signing);
         fixed_rng.add_entropy(entropy_for_signing);

         // Generate Keypair
         Botan::SphincsPlus_PrivateKey priv_key(fixed_rng, params);

         result.test_is_eq("public key bits", priv_key.public_key_bits(), pk_ref);
         result.test_is_eq("private key bits", unlock(priv_key.private_key_bits()), sk_ref);

         // Signature roundtrip (Randomized mode)
         auto signer_rand = Botan::PK_Signer(priv_key, fixed_rng, "Randomized");
         auto signature_rand = signer_rand.sign_message(msg_ref.data(), msg_ref.size(), fixed_rng);

         result.test_is_eq("signature creation randomized", unlock(hash->process(signature_rand)), sig_rand_hash);

         Botan::PK_Verifier verifier(*priv_key.public_key(), params.algorithm_identifier());
         const bool verify_success =
            verifier.verify_message(msg_ref.data(), msg_ref.size(), signature_rand.data(), signature_rand.size());
         result.confirm("verification of valid randomized signature", verify_success);

         // Signature roundtrip (Deterministic mode) - only if the vector provides
         // an expected hash. To save test time this runs only for the
         // SLHDSA128Fast parameter set unless --run-long-tests is given.
         if(sig_det_hash.has_value() &&
            (Test::run_long_tests() || params.parameter_set() == Botan::Sphincs_Parameter_Set::SLHDSA128Fast)) {
            auto signer_det = Botan::PK_Signer(priv_key, fixed_rng, "Deterministic");
            auto signature_det = signer_det.sign_message(msg_ref.data(), msg_ref.size(), fixed_rng);

            result.test_is_eq("signature creation deterministic", unlock(hash->process(signature_det)), *sig_det_hash);

            const bool verify_success_det =
               verifier.verify_message(msg_ref.data(), msg_ref.size(), signature_det.data(), signature_det.size());
            result.confirm("verification of valid deterministic signature", verify_success_det);
         }

         // Keypair deserialization, signing and re-verification.
         //
         // Run these additional tests on one parameter set only to save test
         // runtime. Given the current test vectors this runs exactly once per
         // hash function.
         if(params.parameter_set() == Botan::Sphincs_Parameter_Set::Sphincs128Fast ||
            params.parameter_set() == Botan::Sphincs_Parameter_Set::SLHDSA128Fast) {
            // Deserialization of Keypair from test vector
            Botan::SphincsPlus_PrivateKey deserialized_priv_key(sk_ref, params);
            Botan::SphincsPlus_PublicKey deserialized_pub_key(pk_ref, params);

            // Signature with deserialized Keypair; this consumes the second
            // copy of the signing entropy pushed above.
            auto deserialized_signer = Botan::PK_Signer(deserialized_priv_key, fixed_rng, "Randomized");
            auto deserialized_signature = deserialized_signer.sign_message(msg_ref.data(), msg_ref.size(), fixed_rng);

            result.test_is_eq("signature creation after deserialization",
                              unlock(hash->process(deserialized_signature)),
                              sig_rand_hash);

            // Verification with deserialized Keypair
            Botan::PK_Verifier deserialized_verifier(deserialized_pub_key, params.algorithm_identifier());
            const bool verify_success_deserialized = deserialized_verifier.verify_message(
               msg_ref.data(), msg_ref.size(), signature_rand.data(), signature_rand.size());
            result.confirm("verification of valid signature after deserialization", verify_success_deserialized);

            // Verification of invalid signature
            auto broken_sig = Test::mutate_vec(deserialized_signature, this->rng());
            const bool verify_fail = deserialized_verifier.verify_message(
               msg_ref.data(), msg_ref.size(), broken_sig.data(), broken_sig.size());
            result.confirm("verification of invalid signature", !verify_fail);

            // A failed verification must not leave the verifier in a bad state
            const bool verify_success_after_fail = deserialized_verifier.verify_message(
               msg_ref.data(), msg_ref.size(), signature_rand.data(), signature_rand.size());
            result.confirm("verification of valid signature after broken signature", verify_success_after_fail);
         }

         // Misc: the parameter set must serialize back to the name it was
         // created from. test_is_eq reports both strings on mismatch.
         result.test_is_eq<std::string>("parameter serialization works", params.to_string(), param_set_name);

         return result;
      }
};

/**
 * Runs the shared SPHINCS+/SLH-DSA known-answer checks against the vectors
 * stored in pubkey/sphincsplus.vec.
 */
class SPHINCS_Plus_Test final : public SPHINCS_Plus_Test_Base {
   public:
      SPHINCS_Plus_Test() : SPHINCS_Plus_Test_Base(kat_file) {}

   private:
      static constexpr std::string_view kat_file = "pubkey/sphincsplus.vec";
};

/**
 * Runs the shared SPHINCS+/SLH-DSA known-answer checks against the vectors
 * stored in pubkey/slh_dsa.vec.
 */
class SLH_DSA_Test final : public SPHINCS_Plus_Test_Base {
   public:
      SLH_DSA_Test() : SPHINCS_Plus_Test_Base(kat_file) {}

   private:
      static constexpr std::string_view kat_file = "pubkey/slh_dsa.vec";
};

/**
 * Key generation tests for SPHINCS+ and SLH-DSA parameter sets.
 *
 * Without --run-long-tests only a small subset of instances is exercised.
 * In either case, instances that are not available in this build are
 * filtered out.
 */
class SPHINCS_Plus_Keygen_Tests final : public PK_Key_Generation_Test {
   public:
      std::vector<std::string> keygen_params() const override {
         const std::vector<std::string> short_test_params = {
            "SphincsPlus-shake-128s-r3.1",
            "SLH-DSA-SHAKE-128s",
            "SphincsPlus-sha2-128f-r3.1",
            "SLH-DSA-SHA2-128f",
         };
         const std::vector<std::string> all_params = {
            // SPHINCS+ (SHAKE)
            "SphincsPlus-shake-128s-r3.1",
            "SphincsPlus-shake-128f-r3.1",
            "SphincsPlus-shake-192s-r3.1",
            "SphincsPlus-shake-192f-r3.1",
            "SphincsPlus-shake-256s-r3.1",
            "SphincsPlus-shake-256f-r3.1",
            // SLH-DSA (SHAKE)
            "SLH-DSA-SHAKE-128s",
            "SLH-DSA-SHAKE-128f",
            "SLH-DSA-SHAKE-192s",
            "SLH-DSA-SHAKE-192f",
            "SLH-DSA-SHAKE-256s",
            "SLH-DSA-SHAKE-256f",
            // SPHINCS+ (SHA-2)
            "SphincsPlus-sha2-128s-r3.1",
            "SphincsPlus-sha2-128f-r3.1",
            "SphincsPlus-sha2-192s-r3.1",
            "SphincsPlus-sha2-192f-r3.1",
            "SphincsPlus-sha2-256s-r3.1",
            "SphincsPlus-sha2-256f-r3.1",
            // SLH-DSA (SHA-2)
            "SLH-DSA-SHA2-128s",
            "SLH-DSA-SHA2-128f",
            "SLH-DSA-SHA2-192s",
            "SLH-DSA-SHA2-192f",
            "SLH-DSA-SHA2-256s",
            "SLH-DSA-SHA2-256f",
         };

         const std::vector<std::string>& candidates = Test::run_long_tests() ? all_params : short_test_params;

         // Keep only the instances this build can actually instantiate.
         std::vector<std::string> available_params;
         for(const std::string& candidate : candidates) {
            if(Botan::Sphincs_Parameters::create(candidate).is_available()) {
               available_params.push_back(candidate);
            }
         }
         return available_params;
      }

      std::string algo_name() const override { return "SLH-DSA"; }

      /// Rebuilds a public key of the given parameter set from its raw encoding.
      std::unique_ptr<Botan::Public_Key> public_key_from_raw(std::string_view keygen_params,
                                                             std::string_view /* provider */,
                                                             std::span<const uint8_t> raw_pk) const override {
         auto params = Botan::Sphincs_Parameters::create(keygen_params);
         return std::make_unique<Botan::SphincsPlus_PublicKey>(raw_pk, std::move(params));
      }
};

/**
 * Signature generation tests for SLH-DSA driven by the generic vector file
 * pubkey/slh_dsa_generic.vec.
 */
class Generic_SlhDsa_Signature_Tests final : public PK_Signature_Generation_Test {
   public:
      Generic_SlhDsa_Signature_Tests() :
            PK_Signature_Generation_Test("SLH-DSA",
                                         "pubkey/slh_dsa_generic.vec",
                                         "Instance,Msg,PrivateKey,PublicKey,Valid,Signature",
                                         "Nonce") {}

      bool clear_between_callbacks() const override { return false; }

      /**
       * Load the private key of the vector and cross-check that its public
       * part matches the PublicKey field; throws Test_Error otherwise.
       */
      std::unique_ptr<Botan::Private_Key> load_private_key(const VarMap& vars) override {
         const std::string instance_name = vars.get_req_str("Instance");
         const std::vector<uint8_t> sk_bytes = vars.get_req_bin("PrivateKey");
         const std::vector<uint8_t> expected_pk_bytes = vars.get_req_bin("PublicKey");

         auto private_key =
            std::make_unique<Botan::SphincsPlus_PrivateKey>(sk_bytes, Botan::Sphincs_Parameters::create(instance_name));

         const bool key_is_consistent = (private_key->public_key_bits() == expected_pk_bytes);
         if(!key_is_consistent) {
            throw Test_Error("Invalid SLH-DSA key in test data");
         }

         return private_key;
      }

      /// Vectors carrying a Nonce are signed in randomized mode, all others deterministically.
      std::string default_padding(const VarMap& vars) const override {
         if(vars.has_key("Nonce")) {
            return "Randomized";
         }
         return "Deterministic";
      }

      bool skip_this_test(const std::string& /*header*/, const VarMap& vars) override {
         const bool available = Botan::Sphincs_Parameters::create(vars.get_req_str("Instance")).is_available();
         return !available;
      }
};

/**
 * Signature verification tests for SLH-DSA driven by the generic vector file
 * pubkey/slh_dsa_generic.vec.
 */
class Generic_SlhDsa_Verification_Tests final : public PK_Signature_Verification_Test {
   public:
      Generic_SlhDsa_Verification_Tests() :
            PK_Signature_Verification_Test("SLH-DSA",
                                           "pubkey/slh_dsa_generic.vec",
                                           "Instance,Msg,PrivateKey,PublicKey,Valid,Signature",
                                           "Nonce") {}

      bool clear_between_callbacks() const override { return false; }

      /// Build the public key of the vector from its Instance and PublicKey fields.
      std::unique_ptr<Botan::Public_Key> load_public_key(const VarMap& vars) override {
         const std::string instance_name = vars.get_req_str("Instance");
         const std::vector<uint8_t> pk_bytes = vars.get_req_bin("PublicKey");

         auto params = Botan::Sphincs_Parameters::create(instance_name);
         return std::make_unique<Botan::SphincsPlus_PublicKey>(pk_bytes, std::move(params));
      }

      // Verification uses an empty padding string for every vector.
      std::string default_padding(const VarMap& /*vars*/) const override { return std::string(); }

      bool skip_this_test(const std::string& /*header*/, const VarMap& vars) override {
         const bool available = Botan::Sphincs_Parameters::create(vars.get_req_str("Instance")).is_available();
         return !available;
      }
};

// Register the SPHINCS+/SLH-DSA test suites in the "pubkey" category:
//  - sphincsplus / slh_dsa: known-answer tests sharing SPHINCS_Plus_Test_Base,
//    differing only in the vector file they load
//  - slh_dsa_keygen: key generation tests
//  - slh_dsa_sign_generic / slh_dsa_verify_generic: signing and verification
//    tests driven by pubkey/slh_dsa_generic.vec
BOTAN_REGISTER_TEST("pubkey", "sphincsplus", SPHINCS_Plus_Test);
BOTAN_REGISTER_TEST("pubkey", "slh_dsa", SLH_DSA_Test);
BOTAN_REGISTER_TEST("pubkey", "slh_dsa_keygen", SPHINCS_Plus_Keygen_Tests);
BOTAN_REGISTER_TEST("pubkey", "slh_dsa_sign_generic", Generic_SlhDsa_Signature_Tests);
BOTAN_REGISTER_TEST("pubkey", "slh_dsa_verify_generic", Generic_SlhDsa_Verification_Tests);

}  // namespace Botan_Tests

#endif  // BOTAN_HAS_SPHINCS_PLUS
